#!/usr/bin/python3
# ******************************************************************************
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
# licensed under the Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#     http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN 'AS IS' BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
# PURPOSE.
# See the Mulan PSL v2 for more details.
# ******************************************************************************/
import re
import os
from collections import defaultdict
from typing import Tuple, List, Set, Optional

from ceres.conf.constant import CommandExitCode, TaskExecuteRes
from ceres.function.log import LOGGER
from ceres.function.util import execute_shell_command
from ceres.function.check import PreCheck
from ceres.function.status import SUCCESS, COMMAND_EXEC_ERROR


class CveFixTaskType:
    """String constants for the supported values of a cve fix task's "fix_type" field."""

    # fix by upgrading to a regular (cold patch) package version
    COLDPATCH = "coldpatch"
    # fix by applying a hot patch
    HOTPATCH = "hotpatch"


class CveFixManage:
    """
    Fix CVEs on the host either by upgrading packages (coldpatch, ``dnf upgrade-en``)
    or by applying hotpatches (``dnf hotupgrade``).
    """

    def cve_fix(self, task_info: dict) -> dict:
        """
        fix cves by upgrading packages

        Args:
            task_info(dict): cve info which need to fix and check_items,e.g.
            {
                "fix_type": "coldpatch",
                "check_items": [],
                "rpms": [
                    {
                        "installed_rpm": "xxxxx",
                        "available_rpm": "unzip-6.0-50.oe2203.x86_64",
                    }
                ],
                "accepted": False,
            }
            "accepted" is only used for hotpatch fixes and defaults to False when absent.

        Returns:
            dict: cve fix result e.g
                {
                    "check_items":[
                        {
                            "item":"network",
                            "result":true,
                            "log":"xxxx"
                        }
                    ],
                    "rpms":[
                        {
                            "available_rpm":"kernel-4.19xxx",
                            "result": "succeed",
                            "log": "fix succeed"
                        }
                    ],
                    "dnf_event_start": 1,
                    "dnf_event_end": 5,
                    "status": succeed
                }
            "dnf_event_start"/"dnf_event_end" are only set for hotpatch fixes. Both are None when a
            transaction id could not be queried or the id range does not match the number of upgrades.
        """
        result = {}

        rpms = [rpm.get("available_rpm") for rpm in task_info["rpms"]]
        check_result, items_check_log = PreCheck.execute_check(task_info["check_items"])
        result["check_items"] = items_check_log
        if not check_result:
            LOGGER.warning("The pre-check is failed before execute command!")
            result["rpms"] = self._gen_fail_result(rpms, "pre-check items check failed")
            result["status"] = TaskExecuteRes.FAIL
            return result

        if task_info["fix_type"] == CveFixTaskType.COLDPATCH:
            result["status"], result["rpms"] = self._update_coldpatch_by_dnf_plugin(rpms)
            return result

        # The implementation of the hotpatch upgrade and rollback plan relies on the dnf transaction,
        # so the dnf transaction ID information needs to be returned after the repair is completed.
        result["dnf_event_start"] = self._query_latest_dnf_transaction_id()
        result["status"], result["rpms"], transaction_count = self._update_hotpatch_by_dnf_plugin(
            rpms, task_info.get("accepted", False)
        )
        result["dnf_event_end"] = self._query_latest_dnf_transaction_id()
        event_start, event_end = result["dnf_event_start"], result["dnf_event_end"]
        # A failed id query yields None; only report the range when both ids are known and the range
        # spans exactly the transactions created by the hotpatch upgrades.
        if event_start is None or event_end is None or event_end - event_start != transaction_count:
            result["dnf_event_start"] = result["dnf_event_end"] = None
        return result

    @staticmethod
    def _gen_fail_result(rpms: List[str], log: str) -> List[dict]:
        """
        Build failed fix records that share the same log.

        Args:
            rpms(list): packages that could not be fixed
            log(str): failure reason recorded for every package

        Returns:
            List[dict]: one {"available_rpm", "result", "log"} record per package
        """
        return [{"available_rpm": rpm, "result": TaskExecuteRes.FAIL, "log": log} for rpm in rpms]

    def _update_coldpatch_by_dnf_plugin(self, rpms: List[str]) -> Tuple[str, list]:
        """
        update rpm of list and return their upgrade log

        A package is only upgraded when the upgrade keeps every CVE that is currently fixed by a
        hotpatch fixed (see compare_cve).

        Args:
            rpms(list): List of packages that need to be upgraded

        Returns:
            Tuple[str, List[dict]]
            a tuple containing two elements (update result, Information about each package upgrade log).
        """
        status, fixable_cve_info = self._query_fixable_cve_info()
        if status != SUCCESS:
            return TaskExecuteRes.FAIL, self._gen_fail_result(
                rpms, "Execution of CVE comparison failed due to failure to query fixable CVE information."
            )
        status, fixed_cve_info = self._query_fixed_cve_info_by_hotpatch()
        if status != SUCCESS:
            return TaskExecuteRes.FAIL, self._gen_fail_result(
                rpms, "Execution of CVE comparison failed due to failure to query fixed CVE information."
            )

        final_fix_result, package_update_info = TaskExecuteRes.SUCCEED, []
        for rpm in rpms:
            rpm_fix_info = {"available_rpm": rpm, "result": TaskExecuteRes.SUCCEED, "log": ""}
            compare_result, log = self.compare_cve(rpm, fixable_cve_info, fixed_cve_info)
            if compare_result:
                rpm_fix_info["result"], rpm_fix_info["log"] = self.__update_coldpatch(rpm)
            else:
                rpm_fix_info["result"] = TaskExecuteRes.FAIL
                rpm_fix_info["log"] = log

            if rpm_fix_info["result"] == TaskExecuteRes.FAIL:
                final_fix_result = TaskExecuteRes.FAIL

            package_update_info.append(rpm_fix_info)
        return final_fix_result, package_update_info

    def __update_coldpatch(self, rpm: str) -> Tuple[str, str]:
        """
        upgrade rpm by dnf plugin (coldpatch)

        For kernel packages the upgraded kernel is additionally set as the default grub boot entry.

        Args:
            rpm(str): package that need to be upgraded

        Returns:
            Tuple[str, str]
            a tuple containing two elements (upgrade result, package upgrade log).
        """
        code, stdout, stderr = execute_shell_command([f"dnf upgrade-en {rpm} -y"])
        log = stdout + stderr
        if code != CommandExitCode.SUCCEED:
            LOGGER.error(stderr)
            return TaskExecuteRes.FAIL, log
        if rpm.rsplit("-", 2)[0] == "kernel" and not self.set_default_grub_kernel_version(rpm):
            return TaskExecuteRes.FAIL, log + "\nerror: set default kernel failed!"
        return TaskExecuteRes.SUCCEED, log

    def _update_hotpatch_by_dnf_plugin(self, rpms: List[str], accepted: bool) -> Tuple[str, list, int]:
        """
        upgrade rpm by dnf plugin (hotpatch)

        Args:
            rpms(list): List of packages that need to be upgraded
            accepted(bool): whether to run the "accept" operation on each successfully applied hotpatch

        Returns:
            Tuple[str, List[dict], int]
            a tuple containing three elements (update result, Information about each package upgrade, upgrade count).
            upgrade count only includes successful upgrades whose output does not contain "Nothing to do".
        """
        upgrade_count = 0
        check_result, check_log = PreCheck.kernel_consistency_check()
        if not check_result:
            return (
                TaskExecuteRes.FAIL,
                self._gen_fail_result(rpms, f"kernel consistency check failed\n{check_log}"),
                upgrade_count,
            )

        final_fix_result, package_update_info = TaskExecuteRes.SUCCEED, []

        for rpm in rpms:
            code, stdout, stderr = execute_shell_command([f"dnf hotupgrade {rpm} -y"])
            tmp = {
                "available_rpm": rpm,
                "result": TaskExecuteRes.SUCCEED,
                "log": stdout + stderr,
            }
            if code != CommandExitCode.SUCCEED or "Apply hot patch succeed" not in stdout:
                tmp["result"] = TaskExecuteRes.FAIL
                final_fix_result = TaskExecuteRes.FAIL
            elif "Nothing to do" not in stdout:
                # cve_fix compares this count with the dnf transaction id range
                upgrade_count += 1

            if tmp["result"] == TaskExecuteRes.SUCCEED and accepted:
                try:
                    # drop the ".arch" suffix and everything up to the first "-" (the "patch" prefix)
                    hotpatch_name = rpm.rsplit(".", 1)[0].split("-", 1)[1]
                    _, hotpatch_status_set_log = self._set_hotpatch_status_by_dnf_plugin(hotpatch_name, "accept")
                    tmp["log"] += f"\n\n{hotpatch_status_set_log}"
                except IndexError as error:
                    LOGGER.error(error)
                    tmp["log"] += "\n\nhotpatch status set failed due to can't get correct hotpatch name!"
            package_update_info.append(tmp)
        return final_fix_result, package_update_info, upgrade_count

    @staticmethod
    def _query_fixable_cve_info() -> Tuple[str, dict]:
        """
        Query the CVEs fixed by the upgradeable version of each package

        Returns:
            Tuple[status, dict]
            a tuple containing two elements (status code, fixable_cve_info), grouped by package name
            and then by package version.

        Example:
            "Succeed", {"kernel": {
                "kernel-5.10.0-60.91.0.115.oe2203.x86_64": ["CVE-2023-1829"],
                "kernel-5.10.0-60.91.0.116.oe2203.x86_64": ["CVE-2023-2006"]
                }}
        """
        code, stdout, stderr = execute_shell_command(["dnf updateinfo list cves"])
        if code != CommandExitCode.SUCCEED:
            LOGGER.error("Failed to query update info by dnf!")
            LOGGER.error(stderr)
            return COMMAND_EXEC_ERROR, defaultdict()

        # each match: (cve id, severity, package that fixes it)
        all_cve_info = re.findall(r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(\S+)", stdout)
        rpm_update_info = defaultdict(lambda: defaultdict(list))
        for cve_id, _, coldpatch in all_cve_info:
            rpm_name = coldpatch.rsplit("-", 2)[0]
            rpm_update_info[rpm_name][coldpatch].append(cve_id)

        return SUCCESS, rpm_update_info

    @staticmethod
    def _query_fixed_cve_info_by_hotpatch() -> Tuple[str, dict]:
        """
        Statistics CVE data that will be fixed by hotpatch

        Returns:
            Tuple[status, dict]
            a tuple containing two elements (status code, fixed_cve_info); the dict is empty on failure.

        Example:
            "Succeed", {"kernel": {"CVE-2023-XXXX","CVE-2022-XXXX"}}
        """
        code, stdout, stderr = execute_shell_command(["dnf hot-updateinfo list cves --installed"])
        if code != CommandExitCode.SUCCEED:
            LOGGER.error("Failed to query fixed cves by hotpatch!")
            LOGGER.error(stderr)
            return COMMAND_EXEC_ERROR, defaultdict(set)

        hotpatch_fixed_info = defaultdict(set)
        all_cve_info = re.findall(r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(\S+|-)\s+(patch\S+)", stdout)
        for cve_id, _, _, hotpatch in all_cve_info:
            # drop the last five "-" separated fields and the leading "patch-" to get the package name
            rpm_name = hotpatch.rsplit("-", 5)[0][6:]
            hotpatch_fixed_info[rpm_name].add(cve_id)

        return SUCCESS, hotpatch_fixed_info

    def compare_cve(self, rpm: str, updated_info: dict, hotpatch_fixed_info: dict) -> Tuple[bool, str]:
        """
        Determine whether the packages to be upgraded covers the vulnerabilities fixed by the hotpatch

        Args:
            rpm(str): package that needs to be upgraded
            updated_info(dict): CVEs fixed by each available package version, grouped by package name,
                e.g. {"kernel": {"kernel-5.10.0-60.91.0.115.oe2203.x86_64": ["CVE-2023-1829"]}}
            hotpatch_fixed_info(dict): CVEs currently fixed by hotpatches, grouped by package name,
                e.g. {"kernel": {"CVE-2023-1829"}}

        Returns:
            Tuple[bool, str]
            a tuple containing two elements (compare result, compare log).
            The result is True when no hotpatch-fixed CVE would be re-exposed by the upgrade.
        """
        upgraded_packages: set = self._query_upgraded_packages(rpm)
        if not upgraded_packages:
            return False, "Execution of CVE comparison failed due to failure to query upgraded_packages."

        compare_info = dict()
        for package in upgraded_packages:
            rpm_name = package.rsplit("-", 2)[0]
            # CVEs fixed by every listed version that is not newer than the target version
            fixed_cve_by_coldpatch = set()
            for update_rpm, cve_list in updated_info.get(rpm_name, {}).items():
                # rpm version ordering, not string ordering (which would put 60.100 before 60.91)
                if self._compare_rpm_version(package, update_rpm) >= 0:
                    fixed_cve_by_coldpatch.update(cve_list)
            cve_difference_set = hotpatch_fixed_info.get(rpm_name, set()) - fixed_cve_by_coldpatch
            if cve_difference_set:
                compare_info[rpm_name] = cve_difference_set

        if not compare_info:
            return True, ""

        log = (
            "After upgrading the package, vulnerabilities in the package or in its dependent software package "
            "may be re-exposed. \nHere are some specific vulnerabilities that could potentially re-exposed:\n"
        )
        log += "".join(
            f"{rpm_name}\t{cve_id}\n" for rpm_name, cve_info in compare_info.items() for cve_id in sorted(cve_info)
        )
        return False, log

    @staticmethod
    def _split_evr(nevra: str) -> Optional[Tuple[str, str, str]]:
        """
        Split "name-[epoch:]version-release.arch" into (epoch, version, release).

        Args:
            nevra(str): package string, e.g. "openssl-1:1.1.1m-20.oe2203.x86_64"

        Returns:
            Optional[Tuple[str, str, str]]: epoch defaults to "0"; None when the string has
            fewer than two "-" separators after dropping the arch
        """
        parts = nevra.rsplit(".", 1)[0].rsplit("-", 2)
        if len(parts) != 3:
            return None
        _, version, release = parts
        epoch, _, version = version.rpartition(":")
        return epoch or "0", version, release

    @classmethod
    def _compare_rpm_version(cls, first: str, second: str) -> int:
        """
        Compare the versions of two packages of the same name.

        Epoch, version and release are compared in that order with _rpm_vercmp; name and arch are
        ignored. Strings that cannot be split by _split_evr are compared as a whole.

        Returns:
            int: 1 if first is newer, -1 if second is newer, 0 if they are equal
        """
        first_evr, second_evr = cls._split_evr(first), cls._split_evr(second)
        if first_evr is None or second_evr is None:
            return cls._rpm_vercmp(first, second)
        for first_part, second_part in zip(first_evr, second_evr):
            result = cls._rpm_vercmp(first_part, second_part)
            if result:
                return result
        return 0

    @staticmethod
    def _rpm_vercmp(first: str, second: str) -> int:
        """
        Compare two version (or release) strings following rpm's rpmvercmp algorithm.

        Strings are split into numeric and alphabetic segments; other characters only separate
        segments. Numeric segments compare as integers and are newer than alphabetic ones. "~" sorts
        before everything, including the end of the string. "^" sorts after the end of the string
        but before any other segment.

        Returns:
            int: 1 if first is newer, -1 if second is newer, 0 if they are equal
        """
        if first == second:
            return 0

        def is_digit(char: str) -> bool:
            return "0" <= char <= "9"

        def is_alpha(char: str) -> bool:
            return "a" <= char <= "z" or "A" <= char <= "Z"

        def is_separator(char: str) -> bool:
            return not (is_digit(char) or is_alpha(char) or char in "~^")

        def segment_end(text: str, start: int, predicate) -> int:
            end = start
            while end < len(text) and predicate(text[end]):
                end += 1
            return end

        pos_a, pos_b = 0, 0
        while pos_a < len(first) or pos_b < len(second):
            pos_a = segment_end(first, pos_a, is_separator)
            pos_b = segment_end(second, pos_b, is_separator)
            char_a = first[pos_a] if pos_a < len(first) else ""
            char_b = second[pos_b] if pos_b < len(second) else ""

            if char_a == "~" or char_b == "~":
                if char_a != "~":
                    return 1
                if char_b != "~":
                    return -1
                pos_a, pos_b = pos_a + 1, pos_b + 1
                continue

            if char_a == "^" or char_b == "^":
                if not char_a:
                    return -1
                if not char_b:
                    return 1
                if char_a != "^":
                    return 1
                if char_b != "^":
                    return -1
                pos_a, pos_b = pos_a + 1, pos_b + 1
                continue

            if not (char_a and char_b):
                break

            is_number = is_digit(char_a)
            predicate = is_digit if is_number else is_alpha
            end_a = segment_end(first, pos_a, predicate)
            end_b = segment_end(second, pos_b, predicate)
            segment_a, segment_b = first[pos_a:end_a], second[pos_b:end_b]
            pos_a, pos_b = end_a, end_b

            # segments of different types: the numeric one is newer
            if not segment_b:
                return 1 if is_number else -1

            if is_number:
                segment_a, segment_b = segment_a.lstrip("0"), segment_b.lstrip("0")
                if len(segment_a) != len(segment_b):
                    return 1 if len(segment_a) > len(segment_b) else -1

            if segment_a != segment_b:
                return 1 if segment_a > segment_b else -1

        if pos_a >= len(first) and pos_b >= len(second):
            return 0
        # whichever string still has characters left is newer
        return -1 if pos_a >= len(first) else 1

    @staticmethod
    def _query_upgraded_packages(package: str) -> Set[str]:
        """
        Resolve the packages dnf would upgrade/install for the given package and return them as
        "name-version-release.arch" strings.

        Kernel packages are returned as-is without asking dnf.

        Args:
            package(str): package that needs to be upgraded

        Returns:
            set: empty when the dnf transaction table could not be found in the output
        """
        package_set = set()
        if package.rsplit("-", 2)[0] == "kernel":
            package_set.add(package)
            return package_set

        # The exit code of the command is 1 when input parameters contains assumeno
        _, stdout, _ = execute_shell_command([f"dnf upgrade-en {package} --assumeno"])

        installed_rpm_info = re.findall(r"(Upgrading|Installing):(.*?)Transaction Summary", stdout, re.S)
        if not installed_rpm_info:
            return package_set

        installed_rpm_info_list = installed_rpm_info[0][1].strip().split("\n")
        for single_rpm_info in installed_rpm_info_list:
            # info_list example:
            # ['aops-ceres', 'aarch64', 'v1.3.4-5.oe2203sp2', '@commandline', '107', 'k']
            pkg_info_list = re.split(r'\s+', single_rpm_info.strip())
            # parsing stops at the first line with fewer than five fields (e.g. a blank line or a section header)
            if len(pkg_info_list) < 5:
                break
            package_set.add(f"{pkg_info_list[0]}-{pkg_info_list[2]}.{pkg_info_list[1]}")
        return package_set

    @staticmethod
    def set_default_grub_kernel_version(kernel_rpm_name: str) -> bool:
        """
        Set the boot kernel

        Args:
            kernel_rpm_name(str): The name of the installed kernel package, e.g. "kernel-<version>.<arch>"

        Returns:
            bool: True when grubby set the default kernel successfully
        """
        # strip the leading "kernel-" (7 characters) to get the vmlinuz file suffix
        boot_kernel_path = os.path.join("/boot/", f"vmlinuz-{kernel_rpm_name[7:]}")
        if not os.path.exists(boot_kernel_path):
            LOGGER.error("Can't find target kernel in /boot when set default kernel")
            return False

        LOGGER.info("The Linux boot kernel is about to be changed")
        code, _, stderr = execute_shell_command([f"grubby --set-default={boot_kernel_path}"])

        if code != CommandExitCode.SUCCEED:
            LOGGER.error("The Linux boot kernel change failed")
            LOGGER.error(stderr)
            return False
        LOGGER.info("The Linux boot kernel change successful")
        return True

    @staticmethod
    def _query_latest_dnf_transaction_id() -> Optional[int]:
        """Query latest yum transaction id

        Returns:
            Optional[int]: the id, or None when the command failed or its output is not an integer
        """
        # Example of command execution result:
        # [root@localhost ~]# dnf history
        # ID   | Command line   | Date and time       | Action(s)     | Altered
        # ---------------------------------------------------------------------
        # 3    | rm aops-ceres  | 2023-11-30 09:57    | Removed       | 1
        # 2    | install gcc    | 2023-11-30 09:57    | Install       | 1
        code, stdout, stderr = execute_shell_command(
            ["dnf history", r"grep -E '^\s*[0-9]+'", "head -1", "awk '{print $1}'"]
        )
        if code != CommandExitCode.SUCCEED:
            LOGGER.error(stderr)
            return None

        try:
            return int(stdout)
        except (TypeError, ValueError):
            # e.g. the pipeline produced no output
            LOGGER.error(f"Failed to parse dnf transaction id from output: {stdout!r}")
            return None

    @staticmethod
    def _set_hotpatch_status_by_dnf_plugin(hotpatch: str, operation: str) -> Tuple[bool, str]:
        """
        change hotpatch status by dnf plugin

        Args:
            hotpatch(str):  hotpatch name which you want to change its status
            operation(str): the action that needs to be performed on this hot patch.
                            supported actions: apply,deactive,remove,active,accept
        Returns:
            Tuple[bool, str]
            a tuple containing two elements (operation result, operation log).
        """

        # replace -ACC to /ACC or -SGL to /SGL
        # Example: kernel-5.10.0-153.12.0.92.oe2203sp2-ACC-1-1 >> kernel-5.10.0-153.12.0.92.oe2203sp2/ACC-1-1
        patch_identifier = re.sub(r'-(ACC|SGL)', r'/\1', hotpatch)
        # Example of command execution result:
        # Succeed:
        # [root@openEuler ~]# dnf hotpatch --remove kernel-5.10.0-153.12.0.92.oe2203sp2/ACC-1-1
        # Last metadata expiration check: 3:24:16 ago on Wed 13 Sep 2023 08:16:17 AM CST.
        # Gonna remove this hot patch: kernel-5.10.0-153.12.0.92.oe2203sp2/ACC-1-1
        # remove hot patch 'kernel-5.10.0-153.12.0.92.oe2203sp2/ACC-1-1' succeed
        # Fail:
        # [root@openEuler ~]# dnf hotpatch --accept kernel-5.10.0-153.12.0.92.oe2203sp2/ACC-1-1
        # Last metadata expiration check: 3:25:24 ago on Wed 13 Sep 2023 08:16:17 AM CST.
        # Gonna accept this hot patch: kernel-5.10.0-153.12.0.92.oe2203sp2/ACC-1-1
        # accept hot patch 'kernel-5.10.0-153.12.0.92.oe2203sp2/ACC-1-1' failed, remain original status
        code, stdout, stderr = execute_shell_command([f"dnf hotpatch --{operation} {patch_identifier}"])
        # the plugin may exit successfully while still reporting "failed" in its output
        if code != CommandExitCode.SUCCEED or 'failed' in stdout:
            LOGGER.error(f"hotpatch {hotpatch} set status failed!")
            return False, stdout + stderr

        return True, stdout + stderr
